{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "原文代码作者：https://github.com/wzyonggege/statistical-learning-method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第7章 支持向量机"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "分离超平面：$w^Tx+b=0$\n",
    "\n",
    "点到直线距离：$r=\\frac{|w^Tx+b|}{||w||_2}$\n",
    "\n",
    "$||w||_2$为2-范数：$||w||_2=\\sqrt[2]{\\sum^m_{i=1}w_i^2}$\n",
    "\n",
    "直线为超平面，样本可表示为：\n",
    "\n",
    "$w^Tx+b\\ \\geq+1$\n",
    "\n",
    "$w^Tx+b\\ \\leq+1$\n",
    "\n",
    "#### margin：\n",
    "\n",
    "**函数间隔**：$label(w^Tx+b)\\ or\\ y_i(w^Tx+b)$\n",
    "\n",
    "**几何间隔**：$r=\\frac{label(w^Tx+b)}{||w||_2}$，当数据被正确分类时，几何间隔就是点到超平面的距离\n",
    "\n",
    "为了求几何间隔最大，SVM基本问题可以转化为求解:($\\frac{r^*}{||w||}$为几何间隔，(${r^*}$为函数间隔)\n",
    "\n",
    "$$\\max\\ \\frac{r^*}{||w||}$$\n",
    "\n",
    "$$(subject\\ to)\\ y_i({w^T}x_i+{b})\\geq {r^*},\\ i=1,2,..,m$$\n",
    "\n",
    "分类点几何间隔最大，同时被正确分类。但这个方程并非凸函数求解，所以要先①将方程转化为凸函数，②用拉格朗日乘子法和KKT条件求解对偶问题。\n",
    "\n",
    "①转化为凸函数：\n",
    "\n",
    "先令${r^*}=1$，方便计算（参照衡量，不影响评价结果）\n",
    "\n",
    "$$\\max\\ \\frac{1}{||w||}$$\n",
    "\n",
    "$$s.t.\\ y_i({w^T}x_i+{b})\\geq {1},\\ i=1,2,..,m$$\n",
    "\n",
    "再将$\\max\\ \\frac{1}{||w||}$转化成$\\min\\ \\frac{1}{2}||w||^2$求解凸函数，1/2是为了求导之后方便计算。\n",
    "\n",
    "$$\\min\\ \\frac{1}{2}||w||^2$$\n",
    "\n",
    "$$s.t.\\ y_i(w^Tx_i+b)\\geq 1,\\ i=1,2,..,m$$\n",
    "\n",
    "②用拉格朗日乘子法和KKT条件求解最优值：\n",
    "\n",
    "$$\\min\\ \\frac{1}{2}||w||^2$$\n",
    "\n",
    "$$s.t.\\ -y_i(w^Tx_i+b)+1\\leq 0,\\ i=1,2,..,m$$\n",
    "\n",
    "整合成：\n",
    "\n",
    "$$L(w, b, \\alpha) = \\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$$\n",
    "\n",
    "推导：$\\min\\ f(x)=\\min \\max\\ L(w, b, \\alpha)\\geq \\max \\min\\ L(w, b, \\alpha)$\n",
    "\n",
    "根据KKT条件：\n",
    "\n",
    "$$\\frac{\\partial }{\\partial w}L(w, b, \\alpha)=w-\\sum\\alpha_iy_ix_i=0,\\ w=\\sum\\alpha_iy_ix_i$$\n",
    "\n",
    "$$\\frac{\\partial }{\\partial b}L(w, b, \\alpha)=\\sum\\alpha_iy_i=0$$\n",
    "\n",
    "带入$ L(w, b, \\alpha)$\n",
    "\n",
    "$\\min\\  L(w, b, \\alpha)=\\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^Tw-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i-b\\sum^m_{i=1}\\alpha_iy_i+\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^T\\sum\\alpha_iy_ix_i-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i+\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\alpha_iy_iw^Tx_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)$\n",
    "\n",
    "再把max问题转成min问题：\n",
    "\n",
    "$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n",
    "\n",
    "$ \\alpha_i \\geq 0,i=1,2,...,m$\n",
    "\n",
    "以上为SVM对偶问题的对偶形式\n",
    "\n",
    "-----\n",
    "#### kernel\n",
    "\n",
    "在低维空间计算获得高维空间的计算结果，也就是说计算结果满足高维（满足高维，才能说明高维下线性可分）。\n",
    "\n",
    "#### soft margin & slack variable\n",
    "\n",
    "引入松弛变量$\\xi\\geq0$，对应数据点允许偏离的functional margin 的量。\n",
    "\n",
    "目标函数：$\\min\\ \\frac{1}{2}||w||^2+C\\sum\\xi_i\\qquad s.t.\\ y_i(w^Tx_i+b)\\geq1-\\xi_i$ \n",
    "\n",
    "对偶问题：\n",
    "\n",
    "$$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$$\n",
    "\n",
    "$$s.t.\\ C\\geq\\alpha_i \\geq 0,i=1,2,...,m\\quad \\sum^m_{i=1}\\alpha_iy_i=0,$$\n",
    "\n",
    "-----\n",
    "\n",
    "#### Sequential Minimal Optimization\n",
    "\n",
    "首先定义特征到结果的输出函数：$u=w^Tx+b$.\n",
    "\n",
    "因为$w=\\sum\\alpha_iy_ix_i$\n",
    "\n",
    "有$u=\\sum y_i\\alpha_iK(x_i, x)-b$\n",
    "\n",
    "\n",
    "----\n",
    "\n",
    "$\\max \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j<\\phi(x_i)^T,\\phi(x_j)>$\n",
    "\n",
    "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n",
    "\n",
    "$ \\alpha_i \\geq 0,i=1,2,...,m$\n",
    "\n",
    "-----\n",
    "参考资料：\n",
    "\n",
    "[1] :[Lagrange Multiplier and KKT](http://blog.csdn.net/xianlingmao/article/details/7919597)\n",
    "\n",
    "[2] :[推导SVM](https://my.oschina.net/dfsj66011/blog/517766)\n",
    "\n",
    "[3] :[机器学习算法实践-支持向量机(SVM)算法原理](http://pytlab.org/2017/08/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA-SVM-%E7%AE%97%E6%B3%95%E5%8E%9F%E7%90%86/)\n",
    "\n",
    "[4] :[Python实现SVM](http://blog.csdn.net/wds2006sdo/article/details/53156589)"
   ]
  },
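  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relations derived above ($w=\\sum\\alpha_iy_ix_i$, $\\sum\\alpha_iy_i=0$, and equal primal and dual objective values at the optimum) can be checked numerically on a tiny example. The cell below is a minimal sketch: the three points, their labels and the candidate multipliers `alpha_toy` are assumptions made up for this check, not part of the original notebook, and the multipliers are verified rather than derived."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical three-point toy set (not from the iris data used below)\n",
    "X_toy = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])\n",
    "y_toy = np.array([1.0, 1.0, -1.0])\n",
    "\n",
    "# Candidate dual solution for this toy set, checked below rather than derived\n",
    "alpha_toy = np.array([0.25, 0.0, 0.25])\n",
    "\n",
    "# Recover the primal variables: w = sum_i alpha_i * y_i * x_i\n",
    "w_toy = (alpha_toy * y_toy) @ X_toy\n",
    "# A support vector (alpha_i > 0) satisfies y_i * (w^T x_i + b) = 1, so b = y_i - w^T x_i\n",
    "sv = int(np.argmax(alpha_toy > 0))\n",
    "b_toy = y_toy[sv] - w_toy @ X_toy[sv]\n",
    "\n",
    "# Functional margins y_i * (w^T x_i + b); primal feasibility requires all of them >= 1\n",
    "margins = y_toy * (X_toy @ w_toy + b_toy)\n",
    "\n",
    "# Dual objective: sum_i alpha_i - 1/2 * sum_ij alpha_i alpha_j y_i y_j x_i^T x_j\n",
    "gram = X_toy @ X_toy.T\n",
    "ay = alpha_toy * y_toy\n",
    "dual_value = alpha_toy.sum() - 0.5 * ay @ gram @ ay\n",
    "# Primal objective: 1/2 * ||w||^2\n",
    "primal_value = 0.5 * w_toy @ w_toy\n",
    "\n",
    "print('w =', w_toy, 'b =', b_toy)\n",
    "print('margins =', margins)\n",
    "print('sum(alpha * y) =', ay.sum())\n",
    "print('primal =', primal_value, 'dual =', dual_value)"
   ]
  },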
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import  train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建数据集\n",
    "def create_data():\n",
    "    # 加载鸢尾花数据集\n",
    "    iris = load_iris()\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    # 取前100个数据\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    for i in range(len(data)):\n",
    "        if data[i,-1] == 0:\n",
    "            data[i,-1] = -1\n",
    "    return data[:,:2], data[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 分割数据集，以便交叉验证，分为75%的训练数据，25%的测试数据\n",
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x1a9cf5c0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGcZJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtWNMo0LUwgqQJccHdaq2lQba08VeVlV2DCwYXI4agNiZgSPDHkmj4kQWELXYrlsLyYxGW+CNSMwVs11YiCtoZ2GUotshaoB2++8e9Q2fu3Jl7n3vvmfs8z/28kqZzz316+n3O0S+353zOc83dERGRvPxZuwsQEZHWU3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGTqk3oFm1gUMASPuvrzivbXA1cBIedO17n7DTPs76qijfP78+UHFioh0uq1btz7v7n21xtXd3IGLgZ3A4dO8/z13/0y9O5s/fz5DQ0MBf72IiJjZ7+oZV9dlGTMbAD4AzPhpXERE4lDvNfdvAJ8HXpthzIfMbJuZbTSzedUGmNkFZjZkZkOjo6OhtYqISJ1qNnczWw485+5bZxh2FzDf3RcBPwRurjbI3a9z90F3H+zrq3nJSEREGlTPNfclwAozWwbMAQ43s1vd/SPjA9x994Tx1wNfa22ZIiKts3//foaHh3n55ZfbXcq05syZw8DAAN3d3Q39+ZrN3d0vAy4DMLP3AP88sbGXtx/t7s+WX66gdONVRCRKw8PDHHbYYcyfPx8za3c5U7g7u3fvZnh4mOOOO66hfTScczezdWa2ovzyIjP7pZn9ArgIWNvofkVEivbyyy8zd+7cKBs7gJkxd+7cpv5lERKFxN0fBh4u/3zFhO2vf7oXyc2mx0a4+v4neGbPPo7p7eHSpQtZubi/3WVJk2Jt7OOarS+ouYt0mk2PjXDZHdvZt38MgJE9+7jsju0AavASNS0/IDKDq+9/4vXGPm7f/jGuvv+JNlUkubjvvvtYuHAhCxYs4Kqrrmr5/tXcRWbwzJ59QdtF6jE2NsaFF17Ivffey44dO1i/fj07duxo6d+hyzIiMzimt4eRKo38mN6eNlQj7dLq+y4///nPWbBgAccffzwAq1ev5s477+Qd73hHq0rWJ3eRmVy6dCE93V2TtvV0d3Hp0oVtqkhm2/h9l5E9+3AO3nfZ9NhIzT87nZGREebNO/gg/8DAACMjje+vGjV3kRmsXNzPleecRH9vDwb09/Zw5Tkn6WZqBynivou7T9nW6vSOLsuI1LBycb+aeQcr4r7LwMAAu3btev318PAwxxxzTMP7q0af3EVEZjDd/ZVm7ru8613v4te//jVPPfUUr776KrfffjsrVqyo/QcDqLmLiMygiPsuhxxyCNdeey1Lly7l7W9/O6tWreKEE05ottTJf0dL9yYikpnxS3Ktfkp52bJlLFu2rBUlVqXmLiJSQ4r3XXRZRkQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLtnY9NgIS656iOO+8B8sueqhptb+ECnaJz/5Sd785jdz4oknFrJ/NXfJQhGLO4kUae3atdx3332F7V/NXbKgL9WQQm3bAF8/Eb7cW/p924amd3nGGWdw5JFHtqC46vQQk2RBX6ohhdm2Ae66CPaX/7e0d1fpNcCiVe2rqwZ9cpcsFLG4kwgAD6472NjH7d9X2h4xNXfJgr5UQwqzdzhseyR0WUayUNTiTiIcMVC6FFNte8TU3CUbKS7uJAk464rJ19wBuntK25uwZs0aHn74YZ5//nkGBgb4yle+wvnnn99ksQepuUvTWv3lwSJRGb9p+uC60qWYIwZKjb3Jm6nr169vQXHTU3OXpozny8djiOP5ckANXvKxaFXUyZhqdENVmqJ8uUic1NylKcqXS6rcvd0lzKjZ
+tTcpSnKl0uK5syZw+7du6Nt8O7O7t27mTNnTsP70DV3acqlSxdOuuYOypdL/AYGBhgeHmZ0dLTdpUxrzpw5DAw0HrdUc5emKF8uKeru7ua4445rdxmFqru5m1kXMASMuPvyivcOBW4BTgF2A+e5+9MtrFMipny5SHxCPrlfDOwEDq/y3vnAH9x9gZmtBr4GnNeC+kSSosy/xKKuG6pmNgB8ALhhmiFnAzeXf94InGVm1nx5IunQmvISk3rTMt8APg+8Ns37/cAuAHc/AOwF5jZdnUhClPmXmNRs7ma2HHjO3bfONKzKtikZIzO7wMyGzGwo5rvUIo1Q5l9iUs8n9yXACjN7GrgdONPMbq0YMwzMAzCzQ4AjgBcqd+Tu17n7oLsP9vX1NVW4SGyU+ZeY1Gzu7n6Zuw+4+3xgNfCQu3+kYthm4OPln88tj4nz6QCRgmhNeYlJwzl3M1sHDLn7ZuBG4Ltm9iSlT+yrW1SfSDKU+ZeYWLs+YA8ODvrQ0FBb/m4RkVSZ2VZ3H6w1Tk+oSrQu37Sd9Vt2MeZOlxlrTpvHV1ee1O6yRJKg5i5RunzTdm595Pevvx5zf/21GrxIbVoVUqK0fkuV76ycYbuITKbmLlEam+Ze0HTbRWQyNXeJUtc0q1dMt11EJlNzlyitOW1e0HYRmUw3VCVK4zdNlZYRaYxy7iIiCVHOXZry4et/xk9/c3B5oCVvPZLbPnV6GytqH63RLinSNXeZorKxA/z0Ny/w4et/1qaK2kdrtEuq1NxlisrGXmt7zrRGu6RKzV1kBlqjXVKl5i4yA63RLqlSc5cplrz1yKDtOdMa7ZIqNXeZ4rZPnT6lkXdqWmbl4n6uPOck+nt7MKC/t4crzzlJaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcCxZDBju0hhhqFpHZoebegBgy2KE1xFCziMwe3VBtQAwZ7NAaYqhZRGaPmnsDYshgh9YQQ80iMnvU3BsQQwY7tIYYahaR2aPm3oAYMtihNcRQs4jMHt1QbUAMGezQGmKoWURmT82cu5nNAX4EHErpPwYb3f1LFWPWAlcD418Jf6273zDTfpVzFxEJ18qc+yvAme7+kpl1Az8xs3vd/ZGKcd9z9880UqzMjss3bWf9ll2MudNlxprT5vHVlSc1PTaW/HwsdYjEoGZz99JH+5fKL7vLv9rzWKs07PJN27n1kd+//nrM/fXXlU07ZGws+flY6hCJRV03VM2sy8weB54DHnD3LVWGfcjMtpnZRjOb19IqpWnrt+yqe3vI2Fjy87HUIRKLupq7u4+5+zuBAeBUMzuxYshdwHx3XwT8ELi52n7M7AIzGzKzodHR0WbqlkBj09xbqbY9ZGws+flY6hCJRVAU0t33AA8D76vYvtvdXym/vB44ZZo/f527D7r7YF9fXwPlSqO6zOreHjI2lvx8LHWIxKJmczezPjPrLf/cA7wX+FXFmKMnvFwB7GxlkdK8NadVv1JWbXvI2Fjy87HUIRKLetIyRwM3m1kXpf8YbHD3u81sHTDk7puBi8xsBXAAeAFYW1TB0pjxG6H1JGBCxsaSn4+lDpFYaD13EZGEaD33ghWVqQ7Jlxe575D5pXgskrNtAzy4DvYOwxEDcNYVsGhVu6uSiKm5N6CoTHVIvrzIfYfML8VjkZxtG+Cui2B/Ofmzd1fpNajBy7S0cFgDispUh+TLi9x3yPxSPBbJeXDdwcY+bv++0naRaai5N6CoTHVIvrzIfYfML8VjkZy9w2HbRVBzb0hRmeqQfHmR+w6ZX4rHIjlHDIRtF0HNvSFFZapD8uVF7jtkfikei+ScdQV0V/zHsruntF1kGrqh2oCiMtUh+fIi9x0y
vxSPRXLGb5oqLSMBlHMXEUmIcu4yRQzZdUmc8vbJUHPvEDFk1yVxytsnRTdUO0QM2XVJnPL2SVFz7xAxZNclccrbJ0XNvUPEkF2XxClvnxQ19w4RQ3ZdEqe8fVJ0Q7VDxJBdl8Qpb58U5dxFRBKinHtZUXntkP3Gsi65suuRyT0znvv8QrThWGTd3IvKa4fsN5Z1yZVdj0zumfHc5xeiTcci6xuqReW1Q/Yby7rkyq5HJvfMeO7zC9GmY5F1cy8qrx2y31jWJVd2PTK5Z8Zzn1+INh2LrJt7UXntkP3Gsi65suuRyT0znvv8QrTpWGTd3IvKa4fsN5Z1yZVdj0zumfHc5xeiTcci6xuqReW1Q/Yby7rkyq5HJvfMeO7zC9GmY6Gcu4hIQpRzL5jy8yKJuPsS2HoT+BhYF5yyFpZf0/x+I8/xq7k3QPl5kUTcfQkM3XjwtY8dfN1Mg08gx5/1DdWiKD8vkoitN4Vtr1cCOX419wYoPy+SCB8L216vBHL8au4NUH5eJBHWFba9Xgnk+NXcG6D8vEgiTlkbtr1eCeT4dUO1AcrPiyRi/KZpq9MyCeT4lXMXEUlIy3LuZjYH+BFwaHn8Rnf/UsWYQ4FbgFOA3cB57v50A3XXFJovT20N85Dseu7HotAccUj2uag6ipxf5BnspoTOLedjMYN6Lsu8Apzp7i+ZWTfwEzO7190fmTDmfOAP7r7AzFYDXwPOa3Wxofny1NYwD8mu534sCs0Rh2Sfi6qjyPklkMFuWOjccj4WNdS8oeolL5Vfdpd/VV7LORu4ufzzRuAss9bHNkLz5amtYR6SXc/9WBSaIw7JPhdVR5HzSyCD3bDQueV8LGqoKy1jZl1m9jjwHPCAu2+pGNIP7AJw9wPAXmBulf1cYGZDZjY0OjoaXGxovjy1NcxDsuu5H4tCc8Qh2eei6ihyfglksBsWOrecj0UNdTV3dx9z93cCA8CpZnZixZBqn9KndCR3v87dB919sK+vL7jY0Hx5amuYh2TXcz8WheaIQ7LPRdVR5PwSyGA3LHRuOR+LGoJy7u6+B3gYeF/FW8PAPAAzOwQ4AnihBfVNEpovT20N85Dseu7HotAccUj2uag6ipxfAhnshoXOLedjUUM9aZk+YL+77zGzHuC9lG6YTrQZ+DjwM+Bc4CEvIGMZmi9PbQ3zkOx67sei0BxxSPa5qDqKnF8CGeyGhc4t52NRQ82cu5ktonSztIvSJ/0N7r7OzNYBQ+6+uRyX/C6wmNIn9tXu/tuZ9qucu4hIuJbl3N19G6WmXbn9igk/vwz8XWiRIiJSjOyXH0juwR2ZHSEPtsTwEEyRD+6k9pBWDOcjAVk39+Qe3JHZEfJgSwwPwRT54E5qD2nFcD4SkfWqkMk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uyT24I7Mj5MGWGB6CKfLBndQe0orhfCQi6+ae3IM7MjtCHmyJ4SGYIh/cSe0hrRjORyKybu4rF/dz5Tkn0d/bgwH9vT1cec5Jupna6Ratgg9+C46YB1jp9w9+q/oNuZCxMdQbOr6o+aW23wzpyzpERBLSsoeYRDpeyBd7xCK1mmPJrsdSRwuouYvMJOSLPWKRWs2xZNdjqaNFsr7mLtK0kC/2iEVqNceSXY+ljhZRcxeZScgXe8QitZpjya7HUkeLqLmLzCTkiz1ikVrNsWTXY6mjRdTcRWYS8sUesUit5liy67HU0SJq7iIzWX4NDJ5/8FOvdZVex3hjclxqNceSXY+ljhZRzl1EJCHKucvsSTEbXFTNReXLUzzG0lZq7tKcFLPBRdVcVL48xWMsbadr7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdmru0pwUs8FF1VxUvjzFYyxtp+YuzUkxG1xUzUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2nnLuISELqzbnrk7vkY9sG+PqJ
8OXe0u/bNsz+fouqQSSQcu6Sh6Ky4CH7VR5dIqJP7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iUjPnbmbzgFuAtwCvAde5+zcrxrwHuBN4qrzpDnef8S6Scu4iIuFauZ77AeBz7v6omR0GbDWzB9x9R8W4H7v78kaKlQiluH54SM0pzi8GOm7JqNnc3f1Z4Nnyz380s51AP1DZ3CUXKea1lUcvno5bUoKuuZvZfGAxsKXK26eb2S/M7F4zO6EFtUm7pJjXVh69eDpuSan7CVUzexPwfeCz7v5ixduPAn/p7i+Z2TJgE/C2Kvu4ALgA4Nhjj224aClYinlt5dGLp+OWlLo+uZtZN6XGfpu731H5vru/6O4vlX++B+g2s6OqjLvO3QfdfbCvr6/J0qUwKea1lUcvno5bUmo2dzMz4EZgp7tXXbvUzN5SHoeZnVre7+5WFiqzKMW8tvLoxdNxS0o9l2WWAB8FtpvZ4+VtXwSOBXD37wDnAp82swPAPmC1t2stYWne+M2xlFIRITWnOL8Y6LglReu5i4gkpJU5d4mVMseT3X0JbL2p9IXU1lX6ertmvwVJJFFq7qlS5niyuy+BoRsPvvaxg6/V4KUDaW2ZVClzPNnWm8K2i2ROzT1VyhxP5mNh20Uyp+aeKmWOJ7OusO0imVNzT5Uyx5OdsjZsu0jm1NxTpbXDJ1t+DQyef/CTunWVXutmqnQo5dxFRBKinHsDNj02wtX3P8Eze/ZxTG8Ply5dyMrF/e0uq3Vyz8XnPr8Y6BgnQ829bNNjI1x2x3b27S+lK0b27OOyO7YD5NHgc8/F5z6/GOgYJ0XX3Muuvv+J1xv7uH37x7j6/ifaVFGL5Z6Lz31+MdAxToqae9kze/YFbU9O7rn43OcXAx3jpKi5lx3T2xO0PTm55+Jzn18MdIyTouZedunShfR0T37gpae7i0uXLmxTRS2Wey4+9/nFQMc4KbqhWjZ+0zTbtEzua3HnPr8Y6BgnRTl3EZGE1Jtz12UZkRRs2wBfPxG+3Fv6fduGNPYtbaPLMiKxKzJfrux6tvTJXSR2RebLlV3Plpq7SOyKzJcru54tNXeR2BWZL1d2PVtq7iKxKzJfrux6ttTcRWJX5Nr9+l6AbCnnLiKSEOXcRUQ6mJq7iEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1dRCRDau4iIhmq2dzNbJ6Z/ZeZ7TSzX5rZxVXGmJl9y8yeNLNtZnZyMeVKU7Rut0jHqGc99wPA59z9UTM7DNhqZg+4+44JY94PvK386zTg2+XfJRZat1uko9T85O7uz7r7o+Wf/wjsBCq/WPRs4BYveQToNbOjW16tNE7rdot0lKBr7mY2H1gMbKl4qx/YNeH1MFP/A4CZXWBmQ2Y2NDo6GlapNEfrdot0lLqbu5m9Cfg+8Fl3f7Hy7Sp/ZMqKZO5+nbsPuvtgX19fWKXSHK3bLdJR6mruZtZNqbHf5u53VBkyDMyb8HoAeKb58qRltG63SEepJy1jwI3ATne/Zpphm4GPlVMz7wb2uvuzLaxTmqV1u0U6Sj1pmSXAR4HtZvZ4edsXgWMB3P07wD3AMuBJ4E/AJ1pfqjRt0So1c5EOUbO5u/tPqH5NfeIYBy5sVVEiItIcPaEqIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIasFFFvw19sNgr8ri1/eW1HAc+3u4gCaX7pynluoPnV4y/dvebiXG1r7jEzsyF3H2x3HUXR/NKV89xA82slXZYREcmQmruISIbU3Ku7rt0FFEzzS1fOcwPNr2V0zV1EJEP65C4ikqGObu5m1mVmj5nZ3VXeW2tmo2b2ePnX37ejxmaY2dNmtr1c/1CV983MvmVmT5rZNjM7uR11NqKOub3HzPZOOH9JfeWUmfWa2UYz+5WZ7TSz0yveT/bcQV3zS/b8
mdnCCXU/bmYvmtlnK8YUfv7q+bKOnF0M7AQOn+b977n7Z2axniL8tbtPl6t9P/C28q/TgG+Xf0/FTHMD+LG7L5+1alrrm8B97n6umf058IaK91M/d7XmB4meP3d/AngnlD5AAiPADyqGFX7+OvaTu5kNAB8Abmh3LW10NnCLlzwC9JrZ0e0uqtOZ2eHAGZS+3hJ3f9Xd91QMS/bc1Tm/XJwF/MbdKx/YLPz8dWxzB74BfB54bYYxHyr/k2mjmc2bYVysHPhPM9tqZhdUeb8f2DXh9XB5WwpqzQ3gdDP7hZnda2YnzGZxTToeGAX+tXzZ8AYze2PFmJTPXT3zg3TP30SrgfVVthd+/jqyuZvZcuA5d986w7C7gPnuvgj4IXDzrBTXWkvc/WRK/wS80MzOqHi/2tcnphKfqjW3Ryk9pv1XwL8Am2a7wCYcApwMfNvdFwP/B3yhYkzK566e+aV8/gAoX25aAfx7tberbGvp+evI5k7pS79XmNnTwO3AmWZ268QB7r7b3V8pv7weOGV2S2yeuz9T/v05Stf8Tq0YMgxM/BfJAPDM7FTXnFpzc/cX3f2l8s/3AN1mdtSsF9qYYWDY3beUX2+k1AwrxyR57qhjfomfv3HvBx519/+t8l7h568jm7u7X+buA+4+n9I/mx5y949MHFNx/WsFpRuvyTCzN5rZYeM/A38L/HfFsM3Ax8p37t8N7HX3Z2e51GD1zM3M3mJmVv75VEr/W98927U2wt3/B9hlZgvLm84CdlQMS/LcQX3zS/n8TbCG6pdkYBbOX6enZSYxs3XAkLtvBi4ysxXAAeAFYG07a2vAXwA/KP//4xDg39z9PjP7RwB3/w5wD7AMeBL4E/CJNtUaqp65nQt82swOAPuA1Z7WE3v/BNxW/qf9b4FPZHLuxtWaX9Lnz8zeAPwN8A8Tts3q+dMTqiIiGerIyzIiIrlTcxcRyZCau4hIhtTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ/8PK3ebjfkeaxkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(X[:50,0],X[:50,1], label='0')\n",
    "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SVM:\n",
    "    # max_iter：迭代次数，kernel：所用的核函数，默认为线性SVM\n",
    "    def __init__(self, max_iter=100, kernel='linear'):\n",
    "        self.max_iter = max_iter\n",
    "        self._kernel = kernel\n",
    "    \n",
    "    # 初始化参数\n",
    "    def init_args(self, features, labels):\n",
    "        # m：数据总条数，n：特征值个数\n",
    "        self.m, self.n = features.shape\n",
    "        # X：特征向量\n",
    "        self.X = features\n",
    "        # Y：分类向量\n",
    "        self.Y = labels\n",
    "        # b：截距\n",
    "        self.b = 0.0\n",
    "        \n",
    "        # alpha：拉格朗日乘子向量\n",
    "        self.alpha = np.ones(self.m)\n",
    "        # E：每条数据的误差\n",
    "        self.E = [self._E(i) for i in range(self.m)]\n",
    "        # C：惩罚参数\n",
    "        self.C = 1.0\n",
    "    \n",
    "    # 判断分类\n",
    "    def _KKT(self, i):\n",
    "        y_g = self._g(i)*self.Y[i]\n",
    "        if self.alpha[i] == 0:\n",
    "            return y_g >= 1\n",
    "        elif 0 < self.alpha[i] < self.C:\n",
    "            return y_g == 1\n",
    "        else:\n",
    "            return y_g <= 1\n",
    "    \n",
    "    # g(x)预测值，输入X[i]\n",
    "    def _g(self, i):\n",
    "        r = self.b\n",
    "        for j in range(self.m):\n",
    "            r += self.alpha[j]*self.Y[j]*self.kernel(self.X[i], self.X[j])\n",
    "        return r\n",
    "    \n",
    "    # 核函数\n",
    "    def kernel(self, x1, x2):\n",
    "        if self._kernel == 'linear':\n",
    "            return sum([x1[k]*x2[k] for k in range(self.n)])\n",
    "        elif self._kernel == 'poly':\n",
    "            return (sum([x1[k]*x2[k] for k in range(self.n)]) + 1)**2\n",
    "    \n",
    "        return 0\n",
    "    \n",
    "    # 计算误差，E(x)为g(x)对输入x的预测值和y的差\n",
    "    def _E(self, i):\n",
    "        return self._g(i) - self.Y[i]\n",
    "    \n",
    "    # 计算alpha的值\n",
    "    def _init_alpha(self):\n",
    "        # 符合KKT条件的数据索引，外层循环首先遍历所有满足0<a<C的样本点，检验是否满足KKT\n",
    "        index_list = [i for i in range(self.m) if 0 < self.alpha[i] < self.C]\n",
    "        # 不符合KKT条件的数据索引\n",
    "        non_satisfy_list = [i for i in range(self.m) if i not in index_list]\n",
    "        index_list.extend(non_satisfy_list)\n",
    "        \n",
    "        for i in index_list:\n",
    "            if self._KKT(i):\n",
    "                continue\n",
    "            \n",
    "            E1 = self.E[i]\n",
    "            # 如果E1是+，选择最小的；如果E1是负的，选择最大的\n",
    "            if E1 >= 0:\n",
    "                j = min(range(self.m), key=lambda x: self.E[x])\n",
    "            else:\n",
    "                j = max(range(self.m), key=lambda x: self.E[x])\n",
    "            return i, j\n",
    "    \n",
    "    # 返回最大值\n",
    "    def _compare(self, _alpha, L, H):\n",
    "        if _alpha > H:\n",
    "            return H\n",
    "        elif _alpha < L:\n",
    "            return L\n",
    "        else:\n",
    "            return _alpha      \n",
    "    \n",
    "    def fit(self, features, labels):\n",
    "\n",
    "        self.init_args(features, labels)\n",
    "        \n",
    "        for t in range(self.max_iter):\n",
    "            # 训练数据\n",
    "            i1, i2 = self._init_alpha()\n",
    "            \n",
    "            # 判断边界数据\n",
    "            if self.Y[i1] == self.Y[i2]:\n",
    "                L = max(0, self.alpha[i1]+self.alpha[i2]-self.C)\n",
    "                H = min(self.C, self.alpha[i1]+self.alpha[i2])\n",
    "            else:\n",
    "                L = max(0, self.alpha[i2]-self.alpha[i1])\n",
    "                H = min(self.C, self.C+self.alpha[i2]-self.alpha[i1])\n",
    "                \n",
    "            E1 = self.E[i1]\n",
    "            E2 = self.E[i2]\n",
    "            eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(self.X[i2], self.X[i2]) - 2*self.kernel(self.X[i1], self.X[i2])\n",
    "            if eta <= 0:\n",
    "                # print('eta <= 0')\n",
    "                continue\n",
    "                \n",
    "            alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (E1 - E2) / eta\n",
    "            alpha2_new = self._compare(alpha2_new_unc, L, H)\n",
    "            \n",
    "            alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * (self.alpha[i2] - alpha2_new)\n",
    "            \n",
    "            b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i1]) * (alpha2_new-self.alpha[i2])+ self.b \n",
    "            b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i2]) * (alpha2_new-self.alpha[i2])+ self.b \n",
    "            \n",
    "            # 取新截距\n",
    "            if 0 < alpha1_new < self.C:\n",
    "                b_new = b1_new\n",
    "            elif 0 < alpha2_new < self.C:\n",
    "                b_new = b2_new\n",
    "            else:\n",
    "                # 选择新截距的中点\n",
    "                b_new = (b1_new + b2_new) / 2\n",
    "            \n",
    "            # 更新参数\n",
    "            self.alpha[i1] = alpha1_new\n",
    "            self.alpha[i2] = alpha2_new\n",
    "            self.b = b_new\n",
    "            \n",
    "            self.E[i1] = self._E(i1)\n",
    "            self.E[i2] = self._E(i2)\n",
    "        return 'train done!'\n",
    "            \n",
    "    # 根据模型预测分类\n",
    "    def predict(self, data):\n",
    "        r = self.b\n",
    "        for i in range(self.m):\n",
    "            r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i])\n",
    "            \n",
    "        return 1 if r > 0 else -1\n",
    "    \n",
    "    # 得到正确分类率\n",
    "    def score(self, X_test, y_test):\n",
    "        right_count = 0\n",
    "        for i in range(len(X_test)):\n",
    "            result = self.predict(X_test[i])\n",
    "            if result == y_test[i]:\n",
    "                right_count += 1\n",
    "        return right_count / len(X_test)\n",
    "    \n",
    "    # 计算w的值\n",
    "    def _weight(self):\n",
    "        yx = self.Y.reshape(-1, 1)*self.X\n",
    "        self.w = np.dot(yx.T, self.alpha)\n",
    "        return self.w"
   ]
  },
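  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loops in `_g` and `predict` evaluate $g(x)=\\sum_j\\alpha_jy_jK(x_j,x)+b$ one kernel value at a time. As a sketch of the same formula in vectorized NumPy, the cell below computes the decision values for several points at once and checks that, for the linear kernel, they agree with $w^Tx+b$ using $w=\\sum_j\\alpha_jy_jx_j$ as in `_weight`. The function name `decision_values` and all demo arrays are hypothetical and exist only for this illustration; the sketch covers the linear kernel only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def decision_values(alpha, y, X_sv, b, X_new):\n",
    "    # Linear-kernel Gram matrix between new and training points: K[i, j] = X_new[i] . X_sv[j]\n",
    "    K = X_new @ X_sv.T\n",
    "    # g(x_i) = sum_j alpha_j * y_j * K[i, j] + b\n",
    "    return K @ (alpha * y) + b\n",
    "\n",
    "# Hypothetical multipliers and data, chosen only to exercise the function\n",
    "alpha_demo = np.array([0.25, 0.0, 0.25])\n",
    "y_demo = np.array([1.0, 1.0, -1.0])\n",
    "X_demo = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])\n",
    "b_demo = -2.0\n",
    "X_query = np.array([[0.0, 0.0], [5.0, 5.0]])\n",
    "\n",
    "g_kernel = decision_values(alpha_demo, y_demo, X_demo, b_demo, X_query)\n",
    "\n",
    "# For the linear kernel the same values follow from w = sum_j alpha_j * y_j * x_j\n",
    "w_demo = (alpha_demo * y_demo) @ X_demo\n",
    "g_weight = X_query @ w_demo + b_demo\n",
    "\n",
    "# Predicted labels: +1 where g(x) > 0, otherwise -1\n",
    "labels_demo = np.where(g_kernel > 0, 1, -1)\n",
    "print(g_kernel, g_weight, labels_demo)"
   ]
  },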
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "svm = SVM(max_iter=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'train done!'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svm.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## sklearn.svm.SVC\n",
    "\n",
    "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False,tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=None,random_state=None)*\n",
    "\n",
    "参数：\n",
    "\n",
    "- C：C-SVC的惩罚参数C，默认值是1.0\n",
    "\n",
    "C越大，相当于惩罚松弛变量，希望松弛变量接近0，即对误分类的惩罚增大，趋向于对训练集全分对的情况，这样对训练集测试时准确率很高，但泛化能力弱。C值小，对误分类的惩罚减小，允许容错，将他们当成噪声点，泛化能力较强。\n",
    "\n",
    "- kernel ：核函数，默认是rbf，可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ \n",
    "    \n",
    "    – 线性：u'v\n",
    "    \n",
    "    – 多项式：(gamma*u'*v + coef0)^degree\n",
    "\n",
    "    – RBF函数：exp(-gamma|u-v|^2)\n",
    "\n",
    "    – sigmoid：tanh(gamma*u'*v + coef0)\n",
    "\n",
    "\n",
    "- degree ：多项式poly函数的维度，默认是3，选择其他核函数时会被忽略。\n",
    "\n",
    "\n",
    "- gamma ： ‘rbf’,‘poly’ 和‘sigmoid’的核函数参数。默认是’auto’，则会选择1/n_features\n",
    "\n",
    "\n",
    "- coef0 ：核函数的常数项。对于‘poly’和 ‘sigmoid’有用。\n",
    "\n",
    "\n",
    "- probability ：是否采用概率估计？.默认为False\n",
    "\n",
    "\n",
    "- shrinking ：是否采用shrinking heuristic方法，默认为true\n",
    "\n",
    "\n",
    "- tol ：停止训练的误差值大小，默认为1e-3\n",
    "\n",
    "\n",
    "- cache_size ：核函数cache缓存大小，默认为200\n",
    "\n",
    "\n",
    "- class_weight ：类别的权重，字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)\n",
    "\n",
    "\n",
    "- verbose ：允许冗余输出？\n",
    "\n",
    "\n",
    "- max_iter ：最大迭代次数。-1为无限制。\n",
    "\n",
    "\n",
    "- decision_function_shape ：‘ovo’, ‘ovr’ or None, default=None3\n",
    "\n",
    "\n",
    "- random_state ：数据洗牌时的种子值，int值\n",
    "\n",
    "\n",
    "主要调节的参数有：C、kernel、degree、gamma、coef0。"
   ]
  },
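  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how the `kernel` parameter changes the result, the cell below is a small sketch that compares three kernels with 5-fold cross-validation on the same two-feature iris subset used above. The parameter values in `settings` and the names `X_sub`, `y_sub`, `clf_demo` and `cv_scores` are illustrative assumptions; nothing here is tuned, and the scores depend on the installed scikit-learn version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "# Same subset as above: first 100 iris samples, sepal length and sepal width\n",
    "iris_demo = load_iris()\n",
    "X_sub = iris_demo.data[:100, :2]\n",
    "y_sub = iris_demo.target[:100]\n",
    "\n",
    "# Hypothetical settings to compare; the values are illustrative, not tuned\n",
    "settings = [\n",
    "    {'kernel': 'linear', 'C': 1.0},\n",
    "    {'kernel': 'rbf', 'C': 1.0, 'gamma': 'auto'},\n",
    "    {'kernel': 'poly', 'C': 1.0, 'degree': 2, 'gamma': 'auto'},\n",
    "]\n",
    "\n",
    "cv_scores = {}\n",
    "for params in settings:\n",
    "    clf_demo = SVC(**params)\n",
    "    # Mean accuracy over 5 cross-validation folds\n",
    "    scores = cross_val_score(clf_demo, X_sub, y_sub, cv=5)\n",
    "    cv_scores[params['kernel']] = scores.mean()\n",
    "    print(params, round(scores.mean(), 3))"
   ]
  },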
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
       "  decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',\n",
       "  max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
       "  tol=0.001, verbose=False)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.svm import SVC\n",
    "clf = SVC(gamma='auto')\n",
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
